'''
This script generates an embedding for each sentence in the corpus and saves them to a pickle file, data.pkl.
For each sentence, it saves the ID of the document it belongs to, the text of the sentence, and its embedding as a numpy array.
'''
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import os
import torch

# Fetch the Punkt data used by sent_tokenize for sentence splitting.
# NOTE(review): recent NLTK releases appear to load this data from the
# 'punkt_tab' resource instead -- confirm against the installed NLTK version.
nltk.download('punkt')

# Path of the input pickle and the directory that receives data.pkl.
data_dir = '<reviews_segment.pkl path>'
embed_dir = '<save directory>'
# Run the model on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: unpickling can execute arbitrary code -- only load trusted files here.
data = pd.read_pickle(data_dir)

# Accumulates one record (dict) per sentence across all documents.
all_data = []
# Build the output path with os.path.join so it is valid whether or not
# embed_dir ends with a path separator.
file_path = os.path.join(embed_dir, 'data.pkl')
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Walk the rows positionally rather than by index label, so the loop also
# works when the DataFrame does not carry a default 0..n-1 index.
rows = zip(data['review_id'], data['review_text'])
for raw_id, text in tqdm(rows, total=len(data)):
    # IDs may be wrapped in single quotes; drop them from both ends.
    sample_id = raw_id.strip("'")
    sentences = sent_tokenize(text)
    if not sentences:
        # Nothing to embed for this document; skip the model call.
        continue
    embeddings = model.encode(sentences, convert_to_numpy=True, device=device)

    # One record per sentence: owning document, sentence text, embedding.
    for sent, emb in zip(sentences, embeddings):
        all_data.append({
            "document_id": sample_id,
            "sentence": sent,
            "embedding": emb
        })

# Persist every sentence record as a pickled DataFrame.
df = pd.DataFrame(all_data)
df.to_pickle(file_path)
